{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we are going to evaluate the performance of the naive RAG and the GraphRAG algorithm on a [multi-hop RAG task](https://github.com/yixuantt/MultiHop-RAG)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "Make sure you install the necessary dependencies by running the following commands:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic (not !pip) so the packages land in the running kernel's environment.\n",
    "%pip install ragas nest_asyncio datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the necessary libraries, and set your OpenAI API key if needed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ[\"OPENAI_API_KEY\"] = \"YOUR_API_KEY\"\n",
    "import json\n",
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "\n",
    "import nest_asyncio\n",
    "nest_asyncio.apply()\n",
    "import logging\n",
    "\n",
    "logging.basicConfig(level=logging.WARNING)\n",
    "logging.getLogger(\"nano-graphrag\").setLevel(logging.INFO)\n",
    "from nano_graphrag import GraphRAG, QueryParam\n",
    "from datasets import Dataset \n",
    "from ragas import evaluate\n",
    "from ragas.metrics import (\n",
    "    answer_correctness,\n",
    "    answer_similarity,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the dataset from the [GitHub repo](https://github.com/yixuantt/MultiHop-RAG/tree/main/dataset). \n",
    "It should contain two files:\n",
    "- `MultiHopRAG.json`\n",
    "- `corpus.json`\n",
    "\n",
    "After downloading the dataset, replace the paths below with the paths on your machine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Both fixtures are expected in the same directory; edit `fixtures_dir` to relocate them.\n",
    "fixtures_dir = \"./fixtures\"\n",
    "multi_hop_rag_file = f\"{fixtures_dir}/MultiHopRAG.json\"\n",
    "multi_hop_corpus_file = f\"{fixtures_dir}/corpus.json\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Read both fixtures as UTF-8 explicitly: the articles contain non-ASCII characters\n",
    "# (e.g. typographic apostrophes), and the platform default encoding is not always UTF-8.\n",
    "with open(multi_hop_rag_file, encoding=\"utf-8\") as f:\n",
    "    multi_hop_rag_dataset = json.load(f)\n",
    "with open(multi_hop_corpus_file, encoding=\"utf-8\") as f:\n",
    "    multi_hop_corpus = json.load(f)\n",
    "\n",
    "# Index the corpus by article URL so the evidence URLs of a query can be mapped to articles.\n",
    "corups_url_refernces = {cor['url']: cor for cor in multi_hop_corpus}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We only use the first 100 queries for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Queries have types: {'inference_query', 'comparison_query', 'null_query', 'temporal_query'}\n",
      "We will need 139 articles:\n",
      "## ASX set to drop as Wall Street’s September slump deepens\n",
      "Author: Stan Choe, The Sydney Morning Herald\n",
      "Category: business\n",
      "Publised: 2023-09-26T19:11:30+00:00\n",
      "ETF provider Betashares, which manages $ ...\n"
     ]
    }
   ],
   "source": [
    "# Keep only the first 100 queries; everything below is derived from this subset.\n",
    "multi_hop_rag_dataset = multi_hop_rag_dataset[:100]\n",
    "print(\"Queries have types:\", {q['question_type'] for q in multi_hop_rag_dataset})\n",
    "\n",
    "# Collect the URL of every evidence article referenced by the selected queries.\n",
    "total_urls = {\n",
    "    evidence['url']\n",
    "    for q in multi_hop_rag_dataset\n",
    "    for evidence in q['evidence_list']\n",
    "}\n",
    "# Drop corpus entries that none of the selected queries cite.\n",
    "corups_url_refernces = {\n",
    "    url: article for url, article in corups_url_refernces.items() if url in total_urls\n",
    "}\n",
    "\n",
    "# Render each remaining article as one text document: metadata header lines, then the body.\n",
    "total_corups = [\n",
    "    (\n",
    "        f\"## {cor['title']}\\n\"\n",
    "        f\"Author: {cor['author']}, {cor['source']}\\n\"\n",
    "        f\"Category: {cor['category']}\\n\"\n",
    "        f\"Publised: {cor['published_at']}\\n\"\n",
    "        f\"{cor['body']}\"\n",
    "    )\n",
    "    for cor in corups_url_refernces.values()\n",
    "]\n",
    "\n",
    "print(f\"We will need {len(total_corups)} articles:\")\n",
    "print(total_corups[0][:200], \"...\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Index `total_corups` with both naive RAG and GraphRAG:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:nano-graphrag:Load KV full_docs with 139 data\n",
      "INFO:nano-graphrag:Load KV text_chunks with 408 data\n",
      "INFO:nano-graphrag:Load KV llm_response_cache with 1634 data\n",
      "INFO:nano-graphrag:Load KV community_reports with 794 data\n",
      "INFO:nano-graphrag:Loaded graph from nano_graphrag_cache_multi_hop_rag_test/graph_chunk_entity_relation.graphml with 6181 nodes, 5423 edges\n",
      "WARNING:nano-graphrag:All docs are already in the storage\n",
      "INFO:nano-graphrag:Writing graph with 6181 nodes, 5423 edges\n"
     ]
    }
   ],
   "source": [
    "# The first indexing run takes a long time (roughly 15~20 minutes).\n",
    "graphrag_func = GraphRAG(\n",
    "    working_dir=\"nano_graphrag_cache_multi_hop_rag_test\",\n",
    "    enable_naive_rag=True,\n",
    "    embedding_func_max_async=4,\n",
    ")\n",
    "graphrag_func.insert(total_corups)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Look at the response of different RAG methods on the first query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instruction passed as `response_type` to the answer-producing query parameters below.\n",
    "response_formate = (\n",
    "    \"Single phrase or sentence, concise and no redundant explanation needed. \"\n",
    "    \"If you don't have the answer in context, \"\n",
    "    \"Just response 'Insufficient information'\"\n",
    ")\n",
    "\n",
    "# Naive RAG: one parameter set for answers, one with `only_need_context=True`.\n",
    "naive_rag_query_param = QueryParam(mode=\"naive\", response_type=response_formate)\n",
    "naive_rag_query_only_context_param = QueryParam(mode=\"naive\", only_need_context=True)\n",
    "\n",
    "# Local GraphRAG: the same pair of parameter sets with mode 'local'.\n",
    "local_graphrag_query_param = QueryParam(mode=\"local\", response_type=response_formate)\n",
    "local_graphrag_only_context__param = QueryParam(mode=\"local\", only_need_context=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: Who is the individual associated with the cryptocurrency industry facing a criminal trial on fraud and conspiracy charges, as reported by both The Verge and TechCrunch, and is accused by prosecutors of committing fraud for personal gain?\n",
      "GroundTruth Answer: Sam Bankman-Fried\n"
     ]
    }
   ],
   "source": [
    "# Use the first query of the subset as a quick sanity check.\n",
    "query = multi_hop_rag_dataset[0]\n",
    "print(f\"Question: {query['query']}\")\n",
    "print(f\"GroundTruth Answer: {query['answer']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:nano-graphrag:Truncate 20 to 12 chunks\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NaiveRAG Answer: Sam Bankman-Fried\n"
     ]
    }
   ],
   "source": [
    "# Answer the sample question with naive RAG.\n",
    "naive_rag_answer = graphrag_func.query(query['query'], param=naive_rag_query_param)\n",
    "print(\"NaiveRAG Answer:\", naive_rag_answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:nano-graphrag:Using 20 entites, 3 communities, 124 relations, 3 text units\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Local GraphRAG Answer: Sam Bankman-Fried\n"
     ]
    }
   ],
   "source": [
    "# Answer the same sample question with local-mode GraphRAG.\n",
    "local_graphrag_answer = graphrag_func.query(query['query'], param=local_graphrag_query_param)\n",
    "print(\"Local GraphRAG Answer:\", local_graphrag_answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! Now we're ready to evaluate more detailed metrics. We will use [ragas](https://docs.ragas.io/en/stable/) to evaluate the quality of the answers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parallel lists: questions[i] is the query text and labels[i] its reference answer.\n",
    "questions = [sample['query'] for sample in multi_hop_rag_dataset]\n",
    "labels = [sample['answer'] for sample in multi_hop_rag_dataset]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [03:53<00:00,  2.33s/it]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Raise the nano-graphrag log level to WARNING, presumably so that per-query INFO\n",
    "# messages do not clutter the progress bar output.\n",
    "logging.getLogger(\"nano-graphrag\").setLevel(logging.WARNING)\n",
    "\n",
    "# Answer every evaluation question with naive RAG.\n",
    "naive_rag_answers = [\n",
    "    graphrag_func.query(question, param=naive_rag_query_param)\n",
    "    for question in tqdm(questions)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [09:10<00:00,  5.50s/it]\n"
     ]
    }
   ],
   "source": [
    "# Answer every evaluation question with local-mode GraphRAG.\n",
    "local_graphrag_answers = [\n",
    "    graphrag_func.query(question, param=local_graphrag_query_param)\n",
    "    for question in tqdm(questions)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 70/100 [04:25<01:53,  3.79s/it]8,  6.38it/s]\n",
      "Evaluating: 100%|██████████| 200/200 [00:32<00:00,  6.19it/s]\n"
     ]
    }
   ],
   "source": [
    "# Bundle questions, reference answers and naive RAG answers into one dataset for ragas.\n",
    "naive_eval_dataset = Dataset.from_dict(\n",
    "    {\"question\": questions, \"ground_truth\": labels, \"answer\": naive_rag_answers}\n",
    ")\n",
    "# answer_relevancy is left out here (it is not imported in this notebook).\n",
    "naive_results = evaluate(\n",
    "    naive_eval_dataset,\n",
    "    metrics=[answer_correctness, answer_similarity],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|██████████| 200/200 [00:23<00:00,  8.59it/s]\n"
     ]
    }
   ],
   "source": [
    "# Bundle questions, reference answers and local GraphRAG answers into one dataset for ragas.\n",
    "local_graphrag_eval_dataset = Dataset.from_dict(\n",
    "    {\"question\": questions, \"ground_truth\": labels, \"answer\": local_graphrag_answers}\n",
    ")\n",
    "# answer_relevancy is left out here (it is not imported in this notebook).\n",
    "local_graphrag_results = evaluate(\n",
    "    local_graphrag_eval_dataset,\n",
    "    metrics=[answer_correctness, answer_similarity],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Naive RAG results {'answer_correctness': 0.5896, 'answer_similarity': 0.8935}\n",
      "Local GraphRAG results {'answer_correctness': 0.7380, 'answer_similarity': 0.8619}\n"
     ]
    }
   ],
   "source": [
    "# Print the aggregate scores of both methods one after the other for comparison.\n",
    "for label, scores in (\n",
    "    (\"Naive RAG results\", naive_results),\n",
    "    (\"Local GraphRAG results\", local_graphrag_results),\n",
    "):\n",
    "    print(label, scores)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "baai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
